{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Diffusion from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Diffusion corrupts data and learns to reconstruct it. This allows it to learn distributions in the data so that with new randomness, it can enforce the distributions it learned to generate new data. \n",
    "\n",
    "## Forward Process\n",
    "\n",
    "The forward process is a Markov chain where each step adds noise to the data.\n",
    "\n",
    "For a given data point $x_0$ from the data distribution $p$, at each step $t$ from $0$ to $T$, the noise added is defined by:\n",
    "\n",
    "$x_t = \\sqrt{1 - \\beta_t} \\cdot x_{t-1} + \\sqrt{\\beta_t} \\cdot \\epsilon$\n",
    "\n",
    "where:\n",
    "- $\\beta_t$ is a predefined noise schedule.\n",
    "- $\\epsilon$ is Gaussian noise.\n",
    "\n",
    "### Simplified Form:\n",
    "\n",
    "$x_t = \\alpha_t \\cdot x_0 + \\sigma_t \\cdot \\epsilon$\n",
    "\n",
    "where $\\alpha_t$ and $\\sigma_t$ are computed directly from the noise schedule.\n",
    "\n",
    "---\n",
    "\n",
    "## Reparametrization\n",
    "\n",
    "We use a reparametrization technique to rewrite the equation w.r.t $x_0$. This allows us to calculate the noise at timestep $t$ relative to timestep $0$, rather than relying on $t-1$.\n",
    "\n",
    "This is useful because, in the reverse process, the neural network learns to predict the noise that was added to each timestep in the forward process, to reconstruct the original data **without sequential dependency** on each prior timestep $t-1$. This reduces complexity by not requiring the computation of intermediate steps.\n",
    "\n",
    "### Reparametrized Equation:\n",
    "\n",
    "$x_t = \\alpha_t \\cdot x_0 + \\sigma_t \\cdot \\epsilon$\n",
    "\n",
    "where $\\epsilon$ is the noise the model learns to predict.\n"
   ]
  },
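  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the reparametrization, the sketch below tracks how the one-step update changes the coefficient on $x_0$ and the noise variance, step by step, and compares the result with the closed form. It is an illustration under assumptions rather than part of the implementation: it uses plain Python instead of `torch`, hard-codes the same linear schedule defaults as the `Diffusion` class below, and the variable names (`mean_coef`, `var`, `max_err`) are made up for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Assumed linear schedule, mirroring the defaults used in this notebook\n",
    "T = 1000\n",
    "beta_start, beta_end = 1e-4, 0.02\n",
    "betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]\n",
    "\n",
    "# Step-by-step view: x_t = mean_coef * x_0 + Gaussian noise with variance var\n",
    "mean_coef, var = 1.0, 0.0\n",
    "# Closed-form view: alpha_hat is the cumulative product of (1 - beta)\n",
    "alpha_hat = 1.0\n",
    "max_err = 0.0\n",
    "for beta in betas:\n",
    "    mean_coef *= math.sqrt(1.0 - beta)      # the signal shrinks by sqrt(1 - beta) each step\n",
    "    var = (1.0 - beta) * var + beta         # old noise shrinks, new noise of variance beta is added\n",
    "    alpha_hat *= 1.0 - beta\n",
    "    max_err = max(max_err,\n",
    "                  abs(mean_coef - math.sqrt(alpha_hat)),\n",
    "                  abs(var - (1.0 - alpha_hat)))\n",
    "\n",
    "print(f\"max deviation from closed form: {max_err:.2e}\")\n",
    "print(f\"signal coefficient left at t=T: {math.sqrt(alpha_hat):.4f}\")"
   ]
  },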
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "from torch import optim\n",
    "from matplotlib import pyplot as plt\n",
    "import logging\n",
    "import numpy as np\n",
    "\n",
    "logging.basicConfig(format=\"%(asctime)s - %(levelname)s: %(message)s\", level=logging.INFO, datefmt=\"%I:%M:%S\")\n",
    "\n",
    "# the following values are taken from the original diffusion paper: noise_steps=1000, beta_start=1e-4, beta_end=0.02 \n",
    "class Diffusion:\n",
    "    def __init__(self, total_timesteps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, device=\"cpu\"):\n",
    "        self.total_timesteps = total_timesteps\n",
    "        self.beta_start = beta_start\n",
    "        self.beta_end = beta_end\n",
    "        self.img_size = img_size\n",
    "        self.device = device\n",
    "\n",
    "        self.beta = self.prepare_noise_schedule().to(device)\n",
    "        self.alpha = 1. - self.beta\n",
    "        self.alpha_hat = torch.cumprod(self.alpha, dim=0)\n",
    "\n",
    "    def prepare_noise_schedule(self):\n",
    "        # linear scheduler (note: cosine scheduler is recommended due to smoother noise injection)\n",
    "        return torch.linspace(self.beta_start, self.beta_end, self.total_timesteps)     # create a 1d tensor of evenly spaced values between two end points\n",
    "        \n",
    "    #  x_t = √(α̅ₜ) * x₀ + √(1 - α̅ₜ) * ε\n",
    "    def noise_images(self, x, t):\n",
    "        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]                     # alpha hat is the culumative product of alphas up to timestep t\n",
    "        sqrt_one_minus_alpha_hat = torch.sqrt(1. - self.alpha_hat[t])[:, None, None, None]      # [:, None, None, None] reshapes the tensor to match the image dimensions\n",
    "        epsilon = torch.randn_like(x)\n",
    "        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon\n",
    "\n",
    "    def sample_timesteps(self, n):\n",
    "        return torch.randint(low=1, high=self.noise_steps, size=(n,))\n",
    "    \n",
    "    def sample(self, model, n, labels, cfg_scale=3):\n",
    "        logging.info(f\"Sampling {n} new images....\")\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)       # initialize x as pure gaussian noise (x_t at most noisy state)\n",
    "            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):            # reverse diffusion process (iterate backwards through timsteps t-1 to 1)\n",
    "                t = (torch.ones(n) * i).long().to(self.device)      \n",
    "                predicted_noise = model(x, t, labels)\n",
    "                if cfg_scale > 0:\n",
    "                    uncond_predicted_noise = model(x, t, None)\n",
    "                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)\n",
    "                alpha = self.alpha[t][:, None, None, None]\n",
    "                alpha_hat = self.alpha_hat[t][:, None, None, None]\n",
    "                beta = self.beta[t][:, None, None, None]\n",
    "                if i > 1:\n",
    "                    noise = torch.randn_like(x)\n",
    "                else:\n",
    "                    noise = torch.zeros_like(x)\n",
    "                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise\n",
    "        model.train()\n",
    "        x = (x.clamp(-1, 1) + 1) / 2\n",
    "        x = (x * 255).type(torch.uint8)\n",
    "        return x"
   ]
  },
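  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comment in `prepare_noise_schedule` mentions a cosine scheduler. Below is a minimal sketch of one, following the formulation in Nichol and Dhariwal's *Improved Denoising Diffusion Probabilistic Models*. The function name `cosine_beta_schedule` and its parameters are assumptions for illustration; the offset `s=0.008` and the clipping value `0.999` are the ones given in that paper. It returns a plain Python list, so wrapping the result as `torch.tensor(betas)` should make it usable in place of the `torch.linspace` call, though that swap is not tested here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def cosine_beta_schedule(total_timesteps, s=0.008, max_beta=0.999):\n",
    "    # alpha_bar(step) is the cumulative signal level; s is a small offset that keeps beta tiny near t=0\n",
    "    def alpha_bar(step):\n",
    "        return math.cos((step / total_timesteps + s) / (1 + s) * math.pi / 2) ** 2\n",
    "\n",
    "    betas = []\n",
    "    for t in range(total_timesteps):\n",
    "        beta = 1 - alpha_bar(t + 1) / alpha_bar(t)   # per-step beta from the ratio of consecutive alpha_bar values\n",
    "        betas.append(min(beta, max_beta))            # clip so the final steps do not reach beta = 1\n",
    "    return betas\n",
    "\n",
    "betas = cosine_beta_schedule(1000)\n",
    "print(f\"first beta: {betas[0]:.2e}, last beta: {betas[-1]:.3f}\")"
   ]
  },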
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recommended Resources\n",
    "\n",
    "- [Diffusion Models Explained](https://www.youtube.com/watch?v=fbJac4qQy04&ab_channel=ComputerVisionwithH%C3%BCseyin%C3%96zdemir)\n",
    "- [Diffusion Models Implementation](https://www.youtube.com/watch?v=TBCRlnwJtZU&ab_channel=Outlier)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ML-from-scratch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
